import pprint as pp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import missingno as msno
import xgboost as xgb
import copy as cp
import geoplot as gplt
import geopandas as gpd
%matplotlib inline
sns.set_style("white")
from geoplot import crs as gcrs
from shapely.geometry import Point, Polygon
from itertools import chain, product, cycle
from collections import Counter, OrderedDict
from pandas.tseries.offsets import BDay
from pandas.tseries.holiday import *
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.externals import joblib
from tap import modelmetrics, modelparams
from IPython.display import Image, display, HTML
from keras.utils import plot_model
from keras.models import load_model
# Constants
params = modelparams.get_constants()
dirs = modelparams.get_data_dir()
min_max_scaler = MinMaxScaler()
standard_scaler = StandardScaler()
met = modelmetrics.metrics()
accident_severity_mapping = {
1: "Fatal",
2: "Severe",
3: "Slight"
}
day_of_week_mapping = {
1: "Sun",
2: "Mon",
3: "Tue",
4: "Wed",
5: "Thu",
6: "Fri",
7: "Sat"
}
month_mapping = {
1: "Jan",
2: "Feb",
3: "Mar",
4: "Apr",
5: "May",
6: "Jun",
7: "Jul",
8: "Aug",
9: "Sep",
10: "Oct",
11: "Nov",
12: "Dec"
}
one_hot_feature_mapping = joblib.load("./feature_mapping.oh.pkl.xz")
droppable_columns = np.array(["Longitude", "Latitude", "Date_Time", "Accident_Index",
"Accident_Severity", "Casualty_Severity",
'Location_Easting_OSGR', 'Location_Northing_OSGR'])
ordinal = np.array(["Accident_Severity", "Speed_Limit", "Age_Band_of_Casualty", "Number_of_Vehicles",
"Number_of_Casualties", 'Age_Band_of_Driver'])
acccat = np.array([
'Accident_Severity','Day_of_Week', 'Local_Authority_Highway', 'X1st_Road_Class', 'Road_Type',
'Junction_Detail', 'Junction_Control', 'X2nd_Road_Class', 'Pedestrian_Crossing_Human_Control',
'Pedestrian_Crossing_Physical_Facilities', 'Light_Conditions', 'Weather_Conditions', 'Road_Surface_Conditions',
'Special_Conditions_at_Site', 'Carriageway_Hazards', 'Urban_or_Rural_Area', 'Is_Holiday',
'Year', 'Month','Hour'
])
cascat = np.array([
'Casualty_Class', 'Sex_of_Casualty', 'Casualty_Severity', 'Pedestrian_Location', 'Pedestrian_Movement',
'Car_Passenger', 'Bus_or_Coach_Passenger','Pedestrian_Road_Maintenance_Worker', 'Casualty_Type',
'Casualty_Home_Area_Type', 'Casualty_IMD_Decile'
])
vehcat = np.array([
'Vehicle_Type', 'Towing_and_Articulation', 'Vehicle_Manoeuvre', 'Vehicle_Location_Restricted_Lane',
'Junction_Location', 'Skidding_and_Overturning', 'Hit_Object_in_Carriageway', 'Vehicle_Leaving_Carriageway',
'Hit_Object_off_Carriageway', 'X1st_Point_of_Impact','Was_Vehicle_Left_Hand_Drive', 'Sex_of_Driver',
'Engine_Capacity_CC', 'Propulsion_Code','Driver_IMD_Decile', 'Driver_Home_Area_Type', 'Vehicle_IMD_Decile'
])
# Classes and functions
# For features with more than min_label distinct labels, plot the per-label mean of select_col as a line
# with a horizontal line at the global mean; otherwise, plot a count plot per label hued by select_col
def plot_mean_relation_with(data, fig, select_col, nrow = -1, ncol = 3, ind = 1, exclusion_list = None, min_label = 5):
if not isinstance(select_col, str):
print('Only 1 feature can be selected')
return
if select_col not in data.columns:
        print('Selected column must be present in the dataframe')
return
data_cols = data.columns.drop(select_col) if exclusion_list is None else data.columns.drop(exclusion_list + [select_col])
if nrow < 0:
nrow = np.ceil(len(data_cols) / ncol) + 2
global_mean = data[select_col].mean()
countplot_cols = list()
for col in data_cols:
if len(pd.unique(data[col])) > min_label:
ax = fig.add_subplot(nrow, ncol, ind)
ax.plot(data.loc[:, (col, select_col)].groupby(col).mean())
ax.axhline(global_mean, color = 'r', linestyle = '--')
else:
countplot_cols.append(col)
continue
ax.set_title(col.replace('_', ' '))
if len(data_cols) - len(countplot_cols) - ind <= ncol + len(countplot_cols):
ax.set_xlabel('Label')
if ind % ncol == 1:
ax.set_ylabel(select_col + ' (Mean)')
ind += 1
if (ind - 1) % ncol != 0:
ind += (ind % ncol) + 1
for col in countplot_cols:
ax = fig.add_subplot(nrow, ncol, ind)
sns.countplot(data = data, ax = ax, x = col, hue = select_col)
ind += 1
# Plot letter-value plots for categorical variables; similar to a boxplot but better suited to large datasets
def plot_lv_relation_with(data, fig, dv, hue = None, nrow = -1, ncol = 3, ind = 1, exclusion_list = None,
orient = 'v'):
if not isinstance(dv, str):
        print('Dependent variable must be a single feature')
return
data_cols = data.columns.drop(dv) if exclusion_list is None else data.columns.drop(exclusion_list + [dv])
if nrow < 0:
nrow = np.ceil(len(data_cols) / ncol) + 2
for col in data_cols:
ax = fig.add_subplot(nrow, ncol, ind)
if hue is not None:
sns.lvplot(data = data, x = col, y = dv, hue = hue, ax = ax, orient = orient)
else:
sns.lvplot(data = data, x = col, y = dv, ax = ax, orient = orient)
if ind % ncol != 1:
ax.set_ylabel('')
ind += 1
sns.despine()
return ind
# Drop samples from (after) that are not present in (before)
# Execute this function after samples have been dropped from (before) so the datasets stay consistent
def drop_records_after(before, after):
return after.drop(after.index[after.set_index('Accident_Index').index.isin(before.set_index('Accident_Index').index) ^ True],
axis=0)
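# A toy sketch of the consistency drop (the Accident_Index values below are made up):
toy_before = pd.DataFrame({"Accident_Index": ["A1", "A3"]})
toy_after = pd.DataFrame({"Accident_Index": ["A1", "A2", "A3"], "Vehicle_Type": [9, 9, 19]})
toy_after = drop_records_after(toy_before, toy_after)  # keeps only the A1 and A3 rows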
# Create a calendar of GB bank holidays
class UKBankHolidays(AbstractHolidayCalendar):
rules = [
Holiday('New Years Day', month=1, day=1, observance=sunday_to_monday),
GoodFriday,
EasterMonday,
        Holiday('Early May Bank', month=5, day=1, offset=DateOffset(weekday=MO(1))),
        Holiday('Spring Bank', month=5, day=25, offset=DateOffset(weekday=MO(1))),
        Holiday('Summer Bank', month=8, day=25, offset=DateOffset(weekday=MO(1))),
Holiday('Christmas Day', month=12, day=25, observance=sunday_to_monday),
Holiday('Boxing Day', month=12, day=26, observance=next_monday_or_tuesday)
]
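# A quick sanity check of the calendar (the year is arbitrary); holidays() returns a DatetimeIndex
ukhols_check = UKBankHolidays().holidays(pd.Timestamp("2015-01-01"), pd.Timestamp("2015-12-31"))
pd.Timestamp("2015-12-25") in ukhols_check  # True; Christmas Day 2015 falls on a Friday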
# Date time general line plotter
def dt_general_plotter(ax, data, by, label, scaler = MinMaxScaler()):
return ax.plot(*zip(*[
(yr, cnt[0])
for yr, cnt in zip(range(min(by), max(by) + 1),
scaler.fit_transform(np.array(data.groupby(by).count()).reshape(-1, 1)))
]), label=label)
# Barplot of missing data for each feature with missing data
# Features without missing data are ignored
def plot_yearly_missing_data_cols(data, fig, nrow = -1, ncol = 3, ind = 1):
missing_cols = sorted([(x, col) for col, x in zip(data.columns, data.isnull().sum(axis=0).tolist()) if x > 0])
if nrow < 0:
        nrow = np.ceil(len(missing_cols) / ncol)
for (x, col) in missing_cols:
ax = fig.add_subplot(nrow, ncol, ind)
missing_group = data.loc[:, ('Year', col)].groupby('Year').apply(lambda x: x.loc[x[col].isnull(), 'Year'].count())
percent_missing = np.round((x/len(data)) * 100, 2)
ax.bar(x=missing_group.index, height=missing_group,
               label='Total: ' + str(x) + " (" + (str(percent_missing) if percent_missing >= 0.01 else "<0.01") + "%)")
ax.set_title(col.replace('_', ' '))
if len(missing_cols) - ind < ncol:
ax.set_xlabel('Year')
ax.set_xticks(sorted(data["Year"].unique()))
else:
ax.set_xticks([])
if ind % ncol == 1:
ax.set_ylabel('Frequency')
ax.legend()
ind += 1
# Plot Pearson correlation coefficient
def plot_pearsonr(data):
corr = data.copy().corr()
# Generate a mask for the upper triangle
mask = np.zeros_like(corr, dtype=np.bool)
mask[np.triu_indices_from(mask)] = True
# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(20, 20))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
square=True, linewidths=.5, cbar_kws={"shrink": .5})
# Temporarily drop features in droppable_columns that are not used in modelling
def auto_drop(data):
mask = np.isin(droppable_columns, data.columns)
return data.drop(droppable_columns[mask], axis=1) if any(mask) else data
# Plot stacked barplot to visualise TAS absolute values
def plot_stacked_bar_severity(data, by, ax):
tmp = data.groupby([by, "Accident_Severity"]).count()["Accident_Index"]
slight = list(tmp[list(range(2, len(tmp), 3))])
severe = list(tmp[list(range(1, len(tmp), 3))])
fatal = list(tmp[list(range(0, len(tmp), 3))])
chrono = sorted(data[by].unique())
p1 = ax.bar(chrono, slight, label="Slight")
p2 = ax.bar(chrono, severe, bottom=slight, label="Severe")
p3 = ax.bar(chrono, fatal, bottom=np.add(slight, severe), color='#d62728', label="Fatal")
ax.set_ylabel('Number of Accidents')
ax.set_xlabel(by)
ax.set_xticks(chrono)
if by == "Day_of_Week":
ax.set_xticklabels([day_of_week_mapping[d] for d in chrono])
elif by == "Month":
ax.set_xticklabels([month_mapping[m] for m in chrono])
return [p1, p2, p3]
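# Minimal sketch of the positional slicing above: groupby sorts its keys, so within each group the
# severities appear in the order 1 (Fatal), 2 (Severe), 3 (Slight), assuming every group contains
# all three levels (if a level were missing the positions would shift). Toy data for illustration:
toy = pd.DataFrame({"Month": [1, 1, 1, 2, 2, 2], "Accident_Severity": [1, 2, 3, 1, 2, 3],
                    "Accident_Index": list("abcdef")})
toy_counts = toy.groupby(["Month", "Accident_Severity"]).count()["Accident_Index"]
toy_counts.iloc[list(range(0, len(toy_counts), 3))]  # positions 0, 3, ... -> the Fatal rows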
# Count all categorical features and plot barplot based on TAS
def count_all(data, catls, hue, fig, nrow = -1, ncol = 3, ind = 1,
excl_col=["Day_of_Week", "Year", "Month", "Hour", "Minute", "Accident_Severity"]):
hascat = catls[np.isin(catls, data.columns)]
if excl_col is not None:
hascat = hascat[np.isin(hascat, excl_col, invert=True)]
if nrow <= 0:
nrow = len(hascat)
for cat in hascat:
if len(data[cat].unique()) > 15:
continue # can't see anything
ax = fig.add_subplot(nrow, ncol, ind)
g = sns.countplot(x=cat, data=data, hue=hue, ax=ax)
if ind % ncol != 1:
ax.set_ylabel("")
        ax.xaxis.set_ticklabels(sorted(int(val) for val in data[cat].dropna().unique()))
ind += 1
return ind
# Reverse one-hot encoded columns
def reverse_dummy(data, split_by="_"):
avail_cols = set()
for k, v in one_hot_feature_mapping.items():
if any(np.isin(v, data.columns)):
avail_cols.add(k)
get_cols = dict()
for k in avail_cols:
if len(one_hot_feature_mapping[k]) == 1:
get_cols[k] = data[k]
else:
get_cols[k] = data[one_hot_feature_mapping[k]].idxmax(axis=1)
get_cols = pd.DataFrame(get_cols)
for col in get_cols.columns:
        if not np.issubdtype(get_cols[col].dtype, np.integer):
            tmp = get_cols[col].str.split(split_by, expand=True)
            get_cols[col] = tmp[len(tmp.columns) - 1].astype("int64").copy()
return get_cols.copy()
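# A minimal illustration of the reversal logic on a toy frame; the real mapping comes from
# feature_mapping.oh.pkl.xz and the column names below are made up:
toy_oh = pd.DataFrame({"Road_Type_1": [1, 0], "Road_Type_2": [0, 1]})
toy_labels = toy_oh[["Road_Type_1", "Road_Type_2"]].idxmax(axis=1)  # name of the hot column per row
toy_split = toy_labels.str.split("_", expand=True)
toy_split[len(toy_split.columns) - 1].astype("int64")               # recovers the original codes 1 and 2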
# Get list of tuples for a simple numeric line plot
def simple_line_numeric(x, y):
return zip(*sorted(zip(x, y)))
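# For example, simple_line_numeric([3, 1, 2], [30, 10, 20]) yields ((1, 2, 3), (10, 20, 30)),
# i.e. the points reordered by x so a line plot does not double back on itself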
# Plot confusion matrix as a heatmap
# Reference: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
def plot_confusion_matrix(cm, classes, ax, title='Confusion matrix', cmap=plt.cm.BuGn, shrink=0.35, apply_colorbar=True):
norm_cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
im = ax.imshow(norm_cm, interpolation='nearest', cmap=cmap)
if apply_colorbar:
ax.figure.colorbar(im, ax=ax, shrink=shrink)
ax.set_title(title)
tick_marks = np.arange(len(classes))
ax.set_xticks(tick_marks)
ax.set_xticklabels(classes, rotation=45)
ax.set_yticks(tick_marks)
ax.set_yticklabels(classes)
thresh = cm.max() / 2.
for i, j in product(range(cm.shape[0]), range(cm.shape[1])):
ax.text(j, i, format(cm[i, j], "d") + "\n(" + format(norm_cm[i, j], ".2f") + ")",
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
ax.set_ylabel('True label')
ax.set_xlabel('Predicted label')
return im
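# The heatmap colours use the row-normalised matrix while each cell shows the raw count and the row
# proportion; a small worked example of the normalisation (the numbers are arbitrary):
cm_example = np.array([[50, 10], [20, 80]])
cm_example.astype('float') / cm_example.sum(axis=1)[:, np.newaxis]  # [[0.833, 0.167], [0.2, 0.8]]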
# Convert data to a GeoDataFrame with a point "geometry" column for locational plotting
# Overwrites any existing geometry column
def to_geodf(data, lon="Longitude", lat="Latitude", dropna=True):
data = data.copy().dropna(subset=[lon, lat])
return gpd.GeoDataFrame(data,
geometry=data.apply(lambda srs: Point(srs[lon], srs[lat]), axis='columns'))
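# A quick sketch of the conversion on a two-row frame (the coordinates are arbitrary):
toy_geo = to_geodf(pd.DataFrame({"Longitude": [-0.12, -2.58], "Latitude": [51.50, 51.45]}))
toy_geo.geometry  # a GeoSeries of shapely Points, one per surviving (non-NaN) row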
# Convert data to a GeoDataFrame with a polygon "geometry" column for polygon-based plotting
# Overwrites any existing geometry column
def to_geodf_polygon(data, x="Location_Easting_OSGR", y="Location_Northing_OSGR"):
data = data.copy().dropna(subset=[x, y])
return gpd.GeoDataFrame(data,
geometry=data.apply(lambda srs: Polygon(srs[x], srs[y]), axis="columns"))
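# Note: shapely's Polygon expects a sequence of (x, y) pairs rather than two scalars, so the lambda
# above would need a coordinate list per row; a minimal sketch of a valid construction (made-up coords):
Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])  # a unit square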
# Combine one-hot encoded features into normal columns and return feature importances in dictionary
def remap_feature_importances(cols, feature_importances, do_average=False, do_round=True):
    fi = dict(zip(cols, feature_importances))
    fi_cols = list(fi.keys())
fi_map = dict()
for k, v in one_hot_feature_mapping.items():
mask = np.isin(v, fi_cols)
if any(mask):
            fi_map[k] = np.average([fi[mask_col] for mask_col in v[mask]]) if do_average else np.sum([fi[mask_col] for mask_col in v[mask]])
if do_round:
fi_map[k] = np.round(fi_map[k], 4)
return fi_map
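# A toy illustration of how per-dummy importances roll up into a single figure for the parent feature;
# the names and numbers below are made up (the real mapping is the pickled one_hot_feature_mapping):
toy_fi = {"Road_Type_1": 0.02, "Road_Type_2": 0.01, "Road_Type_3": 0.005}
np.round(np.sum([toy_fi[c] for c in toy_fi]), 4)      # 0.035 (default: dummy importances are summed)
np.round(np.average([toy_fi[c] for c in toy_fi]), 4)  # 0.0117 (with do_average=True)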
# plot ROC graph given ROC data
def plot_roc_auc_curve(roc_data, ax, title, n_classes=3):
lw = 2
ax.plot(roc_data["fpr"]["micro"], roc_data["tpr"]["micro"],
label='micro-average ROC curve (AUC = {0:0.2f})'
''.format(roc_data["roc_auc"]["micro"]),
color='deeppink', linestyle=':', linewidth=4)
ax.plot(roc_data["fpr"]["macro"], roc_data["tpr"]["macro"],
label='macro-average ROC curve (AUC = {0:0.2f})'
''.format(roc_data["roc_auc"]["macro"]),
color='navy', linestyle=':', linewidth=4)
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes), colors):
        ax.plot(roc_data["fpr"][i], roc_data["tpr"][i], color=color, lw=lw,
label='ROC curve of class {0} (AUC = {1:0.2f})'
''.format(accident_severity_mapping[i + 1], roc_data["roc_auc"][i]))
ax.plot([0, 1], [0, 1], 'k--', lw=lw)
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title(title)
ax.legend(loc="lower right")
# plot OOB error rate for RF given RF OOB data
def plot_rf_oob(oob_data, ax, title):
    min_estimators, max_estimators = np.inf, -np.inf
    for label, clf_err in oob_data.items():
        xs, ys = zip(*clf_err)
        min_estimators, max_estimators = min(min_estimators, np.min(xs)), max(max_estimators, np.max(xs))
ax.plot(xs, ys, label=label)
ax.set_xlim(min_estimators, max_estimators)
ax.set_xlabel("Number of Estimators")
ax.set_ylabel("OOB Error Rate")
ax.set_title(title)
ax.legend(loc="upper right")
# Import processed original data
acc = pd.read_csv("./" + dirs["orig"] + 'acc2005_2016.proc.csv', parse_dates=["Date_Time"])
cas = pd.read_csv("./" + dirs["orig"] + 'cas2005_2016.proc.csv')
veh = pd.read_csv("./" + dirs["orig"] + 'veh2005_2016.proc.csv')
The processed dataset has been imported; the code below is kept only for the archive.
# # Replace all -1 missing values to NA
# acc.replace(-1, np.nan, inplace=True)
# veh.replace(-1, np.nan, inplace=True)
# cas.replace(-1, np.nan, inplace=True)
# # Tidy up column names
# acc.rename(lambda x: x.replace('.', ''), axis=1, inplace=True) # Inplace to ensure it changes the variable itself
# acc.rename(columns={
# 'Speed_limit': 'Speed_Limit',
# 'Pedestrian_CrossingHuman_Control': 'Pedestrian_Crossing_Human_Control',
# 'Pedestrian_CrossingPhysical_Facilities': 'Pedestrian_Crossing_Physical_Facilities'
# }, inplace=True)
# veh.rename(lambda x: x.replace('.', ''), axis=1, inplace=True)
# veh.rename(columns={
# 'Vehicle_LocationRestricted_Lane': 'Vehicle_Location_Restricted_Lane'
# }, inplace=True)
# # Date_Time cannot be imputed since there is no way to "guess" the date an accident occurred; remove the associated records
# acc.dropna(subset=['Time', 'Speed_Limit'], inplace=True)
# acc['Speed_Limit'] = acc['Speed_Limit'].astype('int64')
# # We know that if the accident record doesn't exist, the vehicle record cannot exist. Therefore, remove vehicle records
# # that no longer have an associated accident
# veh.drop(veh.index[veh.set_index('Accident_Index').index.isin(acc.set_index('Accident_Index').index) ^ True],
# axis=0, inplace=True)
# # The Was_Vehicle_Left_Hand_Drive feature cannot be imputed as it is one of the factors the project aims to examine
# veh_ind = list(set(veh.loc[veh['Was_Vehicle_Left_Hand_Drive'].isnull(), 'Accident_Index']))
# # Remove every record that has this feature missing, along with vehicles involved in the same Accident_Index
# veh.drop(veh.index[veh.set_index('Accident_Index').index.isin(veh_ind) ^ False], axis=0, inplace=True)
# veh['Was_Vehicle_Left_Hand_Drive'] = veh['Was_Vehicle_Left_Hand_Drive'].astype('int64')
# # Remove every Accidents which is associated with the removed Vehicles
# acc.drop(acc.index[acc.set_index('Accident_Index').index.isin(veh.set_index('Accident_Index').index) ^ True],
# axis=0, inplace=True)
# # Remove noisy features
# acc.drop([
# # 'Location_Easting_OSGR', 'Location_Northing_OSGR',
# 'LSOA_of_Accident_Location',
# 'Local_Authority_District', # Local_Authority_Highway
# 'Police_Force', 'Did_Police_Officer_Attend_Scene_of_Accident', # Police involvement unrelated
# 'X1st_Road_Number', 'X2nd_Road_Number',
# # 'Longitude', 'Latitude', # Not used in modelling but useful in analysis
# ], axis=1, inplace=True)
# veh.drop([
# 'Vehicle_Reference', # Reference value not useful
# 'Age_of_Driver', # Age_Band_of_Driver more general
# 'Journey_Purpose_of_Driver' # Missing from 2005 - 2010
# ], axis=1, inplace=True)
# cas.drop([
# 'Age_of_Casualty', # Duplicate of Age_Band_of_Casualty
# 'Casualty_Reference', 'Vehicle_Reference' # Reference value not useful
# ], axis=1, inplace=True)
# acc.replace({
# 'Road_Type': {9: np.nan},
# 'Weather_Conditions': {9: np.nan}
# }, inplace=True)
# veh.replace({
# 'Sex_of_Driver': {3: np.nan}
# }, inplace=True)
# # Integer encode
# lah_le = LabelEncoder()
# acc[["Local_Authority_Highway"]] = lah_le.fit_transform(acc[["Local_Authority_Highway"]])
# # Combine Date and Time to utilise datetime utility from Pandas
# acc['Date_Time'] = pd.to_datetime(acc['Date'] + ' ' + acc['Time'], format='%d/%m/%Y %H:%M')
# acc.drop(['Date', 'Time'], axis=1, inplace=True)
# # Create a new feature Is_Holiday to indicate whether the accident occurred on a holiday
# ukhols = UKBankHolidays().holidays(min(acc['Date_Time']), max(acc['Date_Time']))
# acc['Is_Holiday'] = 0
# acc.loc[acc['Date_Time'].dt.date.apply(lambda x: x in ukhols), ('Is_Holiday')] = 1
# # Only Minute has no useful pattern, as it is recorded in five-minute intervals.
# # Day is already included as acc.Day_of_Week
# acc['Year'], acc['Month'], acc['Hour'] = acc['Date_Time'].dt.year, acc['Date_Time'].dt.month, acc['Date_Time'].dt.hour
# # Odd or questionable records caught during analysis which are removed later are placed in this cell
# # Speed_Limit == 0 records make no sense, as no road has a speed limit of 0
# acc.drop(acc[acc['Speed_Limit'] == 0].index, inplace=True)
# veh = drop_records_after(acc, veh)
###############END######################
Count the number of samples in each dataset for each year
fig = plt.figure(figsize=[20, 5])
ind = 131
ax = fig.add_subplot(ind)
x, y = simple_line_numeric(acc["Year"].unique(), list(acc.groupby("Year").count()["Accident_Index"]))
ax.bar(x, y)
ax.set_ylabel("Count")
ax.set_xlabel("Year")
ax.set_xticks(sorted(acc["Year"].unique()))
ax.set_title("Accidents")
ind += 1
ax = fig.add_subplot(ind)
x, y = simple_line_numeric(acc["Year"].unique(),
list(cas.merge(acc[["Accident_Index", "Year"]], how="left").groupby("Year").count()["Accident_Index"]))
ax.bar(x, y)
ax.set_xlabel("Year")
ax.set_xticks(sorted(acc["Year"].unique()))
ax.set_title("Casualties")
ind += 1
ax = fig.add_subplot(ind)
x, y = simple_line_numeric(acc["Year"].unique(),
list(veh.merge(acc[["Accident_Index", "Year"]], how="left").groupby("Year").count()["Accident_Index"]))
ax.bar(x, y)
ax.set_xlabel("Year")
ax.set_xticks(sorted(acc["Year"].unique()))
ax.set_title("Vehicles")
ind += 1
fig.suptitle("Count of Samples for Each Dataset")
plt.show()
Analyse general trend
veh_ind = pd.DataFrame({
'Accident_Index': veh['Accident_Index']
}).merge(acc.loc[:, ('Accident_Index', 'Date_Time', 'Day_of_Week')], on = 'Accident_Index', how = 'inner')
cas_ind = pd.DataFrame({
'Accident_Index': cas['Accident_Index']
}).merge(acc.loc[:, ('Accident_Index', 'Date_Time', 'Day_of_Week')], on = 'Accident_Index', how = 'inner')
hold = 'Accident_Index'
dtm = 'Date_Time'
fig = plt.figure(figsize=[20,20])
# fig.set_facecolor("white")
nrow, ncol, ind = 3, 3, 1
# Year
ax = fig.add_subplot(nrow, ncol, ind)
dt_general_plotter(ax, acc[hold], acc[dtm].dt.year, 'Accidents', min_max_scaler)
dt_general_plotter(ax, veh_ind[hold], veh_ind[dtm].dt.year, 'Vehicles', min_max_scaler)
dt_general_plotter(ax, cas_ind[hold], cas_ind[dtm].dt.year, 'Casualties', min_max_scaler)
ax.set_xticks(acc[dtm].dt.year.unique())
ax.set_xlabel("Year")
ax.set_ylabel("Normalised Number of Accidents")
ax.set_title('Number of Accidents per Year')
ind += 1
# Month
ax = fig.add_subplot(nrow, ncol, ind)
dt_general_plotter(ax, acc[hold], acc[dtm].dt.month, 'Accidents', min_max_scaler)
dt_general_plotter(ax, veh_ind[hold], veh_ind[dtm].dt.month, 'Vehicles', min_max_scaler)
dt_general_plotter(ax, cas_ind[hold], cas_ind[dtm].dt.month, 'Casualties', min_max_scaler)
ax.set_xticks(acc[dtm].dt.month.unique())
ax.set_xlabel("Month")
ax.set_xticks(sorted(acc["Month"].unique()))
ax.set_xticklabels([month_mapping[m] for m in sorted(acc["Month"].unique())])
ax.set_title('Number of Accidents per Month')
ax.set_yticks([])
ind += 1
# Day
ax = fig.add_subplot(nrow, ncol, ind)
dt_general_plotter(ax, acc[hold], acc[dtm].dt.day, 'Accidents', min_max_scaler)
dt_general_plotter(ax, veh_ind[hold], veh_ind[dtm].dt.day, 'Vehicles', min_max_scaler)
dt_general_plotter(ax, cas_ind[hold], cas_ind[dtm].dt.day, 'Casualties', min_max_scaler)
ax.set_title('Number of Accidents per Day')
ax.set_xlabel("Day")
ax.set_yticks([])
ind += 1
ax = fig.add_subplot(nrow, ncol, ind)
dt_general_plotter(ax, acc[hold], acc["Day_of_Week"], 'Accidents', min_max_scaler)
dt_general_plotter(ax, veh_ind[hold], veh_ind["Day_of_Week"], 'Vehicles', min_max_scaler)
dt_general_plotter(ax, cas_ind[hold], cas_ind["Day_of_Week"], 'Casualties', min_max_scaler)
ax.set_title('Number of Accidents in Each Day of a Week')
ax.set_xlabel("Day of Week")
ax.set_xticks(sorted(acc["Day_of_Week"].unique()))
ax.set_xticklabels(['Sun', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat'])
ax.set_ylabel("Normalised Number of Accidents")
ind += 1
# Hour
ax = fig.add_subplot(nrow, ncol, ind)
dt_general_plotter(ax, acc[hold], acc[dtm].dt.hour, 'Accidents', min_max_scaler)
dt_general_plotter(ax, veh_ind[hold], veh_ind[dtm].dt.hour, 'Vehicles', min_max_scaler)
dt_general_plotter(ax, cas_ind[hold], cas_ind[dtm].dt.hour, 'Casualties', min_max_scaler)
ax.set_xticks(acc[dtm].dt.hour.unique())
ax.set_xlabel("Hour")
ax.set_yticks([])
ax.set_title('Number of Accidents per Hour')
ind += 1
# Minute
ax = fig.add_subplot(nrow, ncol, ind)
l1, = dt_general_plotter(ax, acc[hold], acc[dtm].dt.minute, 'Accidents', min_max_scaler)
l2, = dt_general_plotter(ax, veh_ind[hold], veh_ind[dtm].dt.minute, 'Vehicles', min_max_scaler)
l3, = dt_general_plotter(ax, cas_ind[hold], cas_ind[dtm].dt.minute, 'Casualties', min_max_scaler)
ax.set_title('Number of Accidents per Minute')
ax.set_xlabel("Minute")
ax.set_yticks([])
ind += 1
fig.legend([l1, l2, l3], ["Accidents", "Vehicles", "Casualties"], loc="center right", fancybox=True, fontsize="large")
plt.show()
Extend the analysis to identify TAS levels in each year, month, day of the week, and hour
fig = plt.figure(figsize=[13, 13])
ind = 221
leg = list()
for c in ["Year", "Month", "Day_of_Week", "Hour"]:
ax = fig.add_subplot(ind)
leg = plot_stacked_bar_severity(acc, c, ax)
ind += 1
fig.suptitle("Frequency of Each Accident Severity")
fig.legend(leg, ["Slight", "Severe", "Fatal"], loc="center right", fancybox=True, fontsize="large", title="Accident Severity")
plt.show()
Discover if location is a factor in TAS
fig = plt.figure(figsize=[20, 20])
years = sorted(acc["Year"].unique())
ncol, ind = 3, 1
nrow = int(np.round(len(years)/ncol))
fig.suptitle("GB Accident Severity Distribution")
for yr in years:
geodata = to_geodf(acc[acc["Year"] == yr])
ax = fig.add_subplot(nrow, ncol, ind)
ax.set_title("Year=" + str(yr))
ax = gplt.pointplot(geodata, alpha=0.2, hue='Accident_Severity', categorical=True, legend_labels=["Fatal", "Severe", "Slight"],
legend=False if yr != 2006 else True, edgecolor='white', linewidth=0.5, legend_kwargs={'loc': 'upper left'}, ax=ax)
ind += 1
Take a closer look at the locational distribution of TAS, as it is quite hard to see in the previous geoplot
tas = sorted(acc["Accident_Severity"].unique())
yrs = sorted(acc["Year"].unique())
colors = {
1: "red",
2: "orange",
3: "grey"
}
for yr in yrs:
fig = plt.figure(figsize=[15, 5])
fig.suptitle("GB Accident Severity " + str(yr) + " Distribution")
ind = 131
tmp = acc[acc["Year"] == yr]
for t in tas:
geodata = to_geodf(tmp[tmp["Accident_Severity"] == t])
ax = fig.add_subplot(ind)
ax.set_title("Accident_Severity=" + accident_severity_mapping[t])
ax = gplt.pointplot(geodata, edgecolor='white', linewidth=0.5, ax=ax, color=colors[t])
ind += 1
Find out how the samples are distributed over the timeline
tmp = acc.merge(veh, on='Accident_Index').merge(cas, on='Accident_Index').groupby(
['Year', 'Month', 'Accident_Severity']).count()['Accident_Index']
ind = 231
fig = plt.figure(figsize=[17, 10])
fig.suptitle("Distribution of Number of Accident From 2005 to 2016")
ax = fig.add_subplot(ind)
for yr in pd.unique(acc['Year']):
sns.kdeplot(tmp[yr], label=str(yr), ax=ax)
ax.set_xlabel("Distribution of Accidents Through Year and Month")
ax.set_ylabel("Density")
ind += 1
ax = fig.add_subplot(ind)
sns.distplot(acc.groupby(['Year', acc['Accident_Severity']]).count()['Accident_Index'].astype("int64"), axlabel='Accident By Year', ax=ax)
ind += 1
ax = fig.add_subplot(ind)
sns.distplot(acc.groupby(['Year', 'Month', acc['Accident_Severity']]).count()['Accident_Index'].astype("int64"), axlabel='Accident By Month', ax=ax)
ind += 1
ax = fig.add_subplot(ind)
sns.distplot(acc.groupby(['Year', 'Month', acc['Date_Time'].dt.day, acc['Accident_Severity']]).count()['Accident_Index'].fillna(0).astype("int64"), axlabel='Accident By Day', ax=ax)
ax.set_ylabel("Density")
ind += 1
ax = fig.add_subplot(ind)
sns.distplot(acc.groupby(['Year', 'Month', acc['Date_Time'].dt.day, acc['Date_Time'].dt.hour, acc['Accident_Severity']]).count()['Accident_Index'].fillna(0).astype("int64"), axlabel='Accident By Hour', ax=ax)
ind += 1
ax = fig.add_subplot(ind)
sns.distplot(acc.groupby(['Year', 'Month', acc['Date_Time'].dt.day, acc['Date_Time'].dt.hour, acc['Date_Time'].dt.minute, acc['Accident_Severity']]).count()['Accident_Index'].fillna(0).astype("int64"), axlabel='Accident By Minute', ax=ax)
plt.show()
Look into the distribution of the ordinal features
fig = plt.figure(figsize=[20, 12])
nrow, ncol, ind = 3, 3, 1
ind = plot_lv_relation_with(data = acc[acc.columns[acc.columns.isin(ordinal)]],
fig = fig, dv = 'Accident_Severity', orient = 'h', ncol = ncol, nrow = nrow, ind = ind)
tmp = cas.merge(acc[["Accident_Index", "Accident_Severity"]], how="left")
ind = plot_lv_relation_with(data = tmp[tmp.columns[tmp.columns.isin(ordinal)]],
fig = fig, dv = 'Accident_Severity', orient = 'h', ncol = ncol, nrow = nrow, ind = ind)
tmp = veh.merge(acc[["Accident_Index", "Accident_Severity"]], how="left")
ind = plot_lv_relation_with(data = tmp[tmp.columns[tmp.columns.isin(ordinal)]],
fig = fig, dv = 'Accident_Severity', orient = 'h', ncol = ncol, nrow = nrow, ind = ind)
fig.suptitle("Letterplot of Ordinal Features")
plt.show()
del tmp
Look into the distribution of the categorical features
fig = plt.figure(figsize=[20, 170])
nrow = len(acccat) + len(vehcat) + len(cascat)
ind = count_all(acc, acccat, hue="Accident_Severity", fig=fig, nrow=nrow)
ind = count_all(cas.merge(acc[["Accident_Index", "Accident_Severity"]], how="left"), cascat, hue="Accident_Severity", fig=fig, nrow=nrow, ind=ind)
ind = count_all(veh.merge(acc[["Accident_Index", "Accident_Severity"]], how="left"), vehcat, hue="Accident_Severity", fig=fig, nrow=nrow, ind=ind)
###############END######################
Analyse missing data correlation between features
msno.heatmap(auto_drop(acc.merge(veh, on='Accident_Index', how='inner').merge(cas, on='Accident_Index', how='inner')), cmap='BuPu')
Look into cluster distance between features for possible patterns in missingness
msno.dendrogram(auto_drop(acc.merge(veh, on = 'Accident_Index', how = 'inner').merge(cas, on='Accident_Index', how='inner')))
Get absolute numbers on missing data
# How many values are missing each year? The survey methodology may differ by year
plot_yearly_missing_data_cols(data = auto_drop(acc), fig = plt.figure(figsize=[20,15]))
plt.show()
plot_yearly_missing_data_cols(data = auto_drop(veh.merge(acc.loc[:, ('Accident_Index', 'Year')], how = 'left')),
fig = plt.figure(figsize=[20, 25]))
plt.show()
plot_yearly_missing_data_cols(data = auto_drop(cas.merge(acc.loc[:, ('Accident_Index', 'Year')], on = 'Accident_Index', how = 'inner')),
fig = plt.figure(figsize=[20, 10]))
plt.show()
# The imputed dataset will be imported in the next stage
# # Remove features with > 10% missing data
# # The rest will be imputed with R Mice
# acc.drop(['Junction_Control', 'X2nd_Road_Class'], axis=1, inplace=True)
# veh.drop(['Age_Band_of_Driver', 'Engine_Capacity_CC', 'Propulsion_Code', 'Age_of_Vehicle', 'Driver_IMD_Decile',
# 'Driver_Home_Area_Type', 'Vehicle_IMD_Decile'], axis=1, inplace=True)
# cas.drop(['Casualty_Home_Area_Type', 'Pedestrian_Road_Maintenance_Worker', 'Casualty_IMD_Decile'], axis=1, inplace=True)
# # Drop year 2016 data as it has an unusually high proportion of missing values in the unfiltered feature set
# acc.drop(acc.index[acc['Year'] == 2016], inplace=True)
# veh = drop_records_after(acc, veh)
# cas = drop_records_after(acc, cas)
# # According to the dendrogram, Sex_of_Driver is MNAR as it has no relationship with other features,
# # so drop the records in which it is missing
# veh.dropna(subset=['Sex_of_Driver'], inplace=True)
# acc = drop_records_after(veh, acc)
# cas.dropna(subset=['Sex_of_Casualty'], inplace=True)
###############END######################
# Import the imputed data
iacc = pd.read_csv("./" + dirs["imp"] + "acc2005_2015-v2018.2.5.imp.csv").merge(acc[["Accident_Index", "Number_of_Casualties", "Local_Authority_Highway"]], how="left")
iveh = pd.read_csv("./" + dirs["imp"] + "veh2005_2015-v2018.2.5.imp.csv")
icas = pd.read_csv("./" + dirs["imp"] + "cas2005_2015-v2018.2.5.imp.csv")
Examine the categorical distributions of the imputed dataset and compare them with the distributions before imputation
fig = plt.figure(figsize=[20, 20])
fig.suptitle("Accident Severity Count for Each Categorical Feature")
nrow = 6
ind = count_all(iacc, acccat, hue="Accident_Severity", fig=fig, nrow=nrow)
ind = count_all(icas.merge(iacc[["Accident_Index", "Accident_Severity"]], how="left"), cascat, hue="Accident_Severity", fig=fig, nrow=nrow, ind=ind)
ind = count_all(iveh.merge(iacc[["Accident_Index", "Accident_Severity"]], how="left"), vehcat, hue="Accident_Severity", fig=fig, nrow=nrow, ind=ind)
Now look into ordinal features
fig = plt.figure(figsize=[20, 12])
nrow, ncol, ind = 3, 3, 1
ind = plot_lv_relation_with(data = iacc[iacc.columns[iacc.columns.isin(ordinal)]],
fig = fig, dv = 'Accident_Severity', orient = 'h', ncol = ncol, nrow = nrow, ind = ind)
tmp = icas.merge(iacc[["Accident_Index", "Accident_Severity"]], how="left")
ind = plot_lv_relation_with(data = tmp[tmp.columns[tmp.columns.isin(ordinal)]],
fig = fig, dv = 'Accident_Severity', orient = 'h', ncol = ncol, nrow = nrow, ind = ind)
tmp = iveh.merge(iacc[["Accident_Index", "Accident_Severity"]], how="left")
ind = plot_lv_relation_with(data = tmp[tmp.columns[tmp.columns.isin(ordinal)]],
fig = fig, dv = 'Accident_Severity', orient = 'h', ncol = ncol, nrow = nrow, ind = ind)
fig.suptitle("Letterplot of Ordinal Features")
plt.show()
del tmp
##############END########################
Analyse GBT feature selection ranking
tmp = joblib.load("./plots/xgb_feature_importance.cas.pkl.xz")
f = dict()
for k, v in tmp.items():
f[k] = int(np.round(v))
fig = plt.figure(figsize=[20, 20])
ax = fig.add_subplot(211)
xgb.plot_importance(f, ax=ax, title="Feature Importance for Casualty Inclusive Features")
tmp = joblib.load("./plots/xgb_feature_importance.nocas.pkl.xz")
f = dict()
for k, v in tmp["xgtf"].items():
f[k] = int(np.round(v))
ax = fig.add_subplot(212)
xgb.plot_importance(f, ax=ax, title="Feature Importance for No Casualty Features")
plt.show()
del f, tmp
baseline_ratio = np.divide(tap["Accident_Severity"].value_counts(), tap["Accident_Severity"].value_counts().sum())
every_year_num_of_acc = tap.groupby("Year").count()["Accident_Index"]
ratio_store = OrderedDict([
(sev, []) for sev in sorted(pd.unique(tap["Accident_Severity"]))
])
for sev, _ in ratio_store.items():
severity_count = tap[tap["Accident_Severity"] == sev].groupby("Year").count()["Accident_Index"]
ratio_store[sev] = np.divide(severity_count, every_year_num_of_acc)
fig = plt.figure(figsize=[20,5])
ind = 131
for sev, ratio in ratio_store.items():
ax = fig.add_subplot(ind)
ax.bar(sorted(pd.unique(tap["Year"])), ratio)
ax.axhline(baseline_ratio[sev], color="red")
ax.set_title("Accident_Severity=" + accident_severity_mapping[sev])
ax.set_xticks(sorted(pd.unique(tap["Year"])))
ax.set_xlabel("Year")
if ind == 131:
ax.set_ylabel("Ratio to Total Accidents")
ind += 1
sampled = joblib.load("test/stratified_traintest.pkl.xz")
ignore_cols = ['Accident_Index', 'Date_Time', 'Year']
sel_cols = ['Accident_Severity', 'Number_of_Vehicles', 'Number_of_Casualties']
# Only interested in training data
df_sampled = pd.DataFrame(sampled["X"].copy())
df_sampled["Accident_Severity"] = sampled["Y"].copy()
_, x_sample, _, y_sample = train_test_split(df_sampled.drop("Accident_Severity", axis=1), df_sampled["Accident_Severity"],
test_size=0.2, stratify=df_sampled["Accident_Severity"])
geo_df_sampled = df_sampled.drop_duplicates(subset=["Accident_Index"]).merge(acc[["Accident_Index", "Longitude", "Latitude"]], how="left", on="Accident_Index")
del sampled
g = sns.pairplot(x_sample.join(y_sample)[sel_cols], hue="Accident_Severity")
new_labels = ['Fatal', 'Severe', 'Slight']
for t, l in zip(g._legend.texts, new_labels): t.set_text(l)
cas_sample = pd.read_csv("./train/stratified_XY_train.oh.tlsmote.cas.v2.csv")
sample_cols = joblib.load("./train/stratified_XY_train.oh.tlsmote.cols.pkl.xz")
rev_cas_sample = reverse_dummy(cas_sample)
rev_cas_sample["Accident_Severity"] = cas_sample["Accident_Severity"].copy()
_, x_sample, _, y_sample = train_test_split(rev_cas_sample[rev_cas_sample.columns[rev_cas_sample.columns.isin(sel_cols)]].drop("Accident_Severity", axis=1), rev_cas_sample["Accident_Severity"],
test_size=0.2, stratify=rev_cas_sample["Accident_Severity"])
g = sns.pairplot(x_sample.join(y_sample)[sel_cols], hue="Accident_Severity")
new_labels = ['Fatal', 'Severe', 'Slight']
for t, l in zip(g._legend.texts, new_labels): t.set_text(l)
# Import the data used for modelling
tap = iacc.merge(icas, on="Accident_Index", how="inner").merge(iveh, on="Accident_Index", how="inner")
all_train_columns = np.invert(tap.columns.isin([
'Accident_Index', 'Accident_Severity', 'Date_Time', 'Year'
]))
freq_columns = ['Number_of_Vehicles', 'Speed_Limit', 'Age_Band_of_Casualty', 'Number_of_Casualties']
cat_columns = tap.columns[all_train_columns].drop(freq_columns)
cas_columns = ['Casualty_Class', 'Pedestrian_Location', 'Pedestrian_Movement', 'Car_Passenger', 'Bus_or_Coach_Passenger', 'Casualty_Type', 'Sex_of_Casualty', 'Age_Band_of_Casualty', 'Number_of_Casualties']
rs_casres = pd.read_csv("./search/rf/randomsearch_rf.cas.v2.csv").sort_values("rank_test_f1_macro")
rs_nocasres = pd.read_csv("./search/rf/randomsearch_rf.nocas.v2.csv").sort_values("rank_test_f1_macro")
# Metrics
rs_casres[[
'mean_test_accuracy',
'mean_test_f1_macro', 'mean_test_precision_macro',
'mean_test_recall_macro', 'param_max_depth', 'param_max_features',
'param_min_samples_leaf', 'param_n_estimators'
]]
# Metrics
rs_nocasres[[
'mean_test_accuracy',
'mean_test_f1_macro', 'mean_test_precision_macro',
'mean_test_recall_macro', 'param_max_depth', 'param_max_features',
'param_min_samples_leaf', 'param_n_estimators'
]]
nn_casres = pd.read_csv("./search/nn/randomsearch_nn.cas.v2.csv").sort_values("rank_test_f1_macro")
nn_nocasres = pd.read_csv("./search/nn/randomsearch_nn.nocas.v2.csv").sort_values("rank_test_f1_macro")
# Metrics
nn_nocasres[[
'mean_score_time', 'mean_test_accuracy',
'mean_test_f1_macro', 'mean_test_precision_macro',
'mean_test_recall_macro',
'param_batch_size', 'param_drop_rate',
'param_epochs', 'param_kernel_init', 'param_optimizer', 'param_units',
'param_units_2'
]]
# Metrics
nn_casres[[
'mean_score_time', 'mean_test_accuracy',
'mean_test_f1_macro', 'mean_test_precision_macro',
'mean_test_recall_macro',
'param_batch_size', 'param_drop_rate',
'param_epochs', 'param_kernel_init', 'param_optimizer', 'param_units',
'param_units_2'
]]
test_data = joblib.load("./test/stratified_traintest.pkl.xz")
rf_cas = joblib.load("./final/rf.final.cas.v2.pkl.xz")
train_cas = pd.read_csv("./train/stratified_XY_train.oh.tlsmote.cas.v2.csv")
rf_nocas = joblib.load("./final/rf.final.nocas.v2.pkl.xz")
train_nocas = pd.read_csv("./train/stratified_XY_train.oh.tlsmote.nocas.v2.csv")
test_data_oh = joblib.load("./test/stratified_traintest.oh.pkl.xz")
nocas_oob = joblib.load("./rf_oob/rf_oob-score.nocas.v2.pkl.xz")
cas_oob = joblib.load("./rf_oob/rf_oob-score.cas.v2.pkl.xz")
fig = plt.figure(figsize=[20, 8])
ax = fig.add_subplot(121)
plot_rf_oob(nocas_oob, ax=ax, title="OOB Error Rate against Number of Estimators for No Casualty Features")
ax = fig.add_subplot(122)
plot_rf_oob(cas_oob, ax=ax, title="OOB Error Rate against Number of Estimators for Casualty Inclusive Features")
plt.show()
fig = plt.figure(figsize=[15, 15])
ax = fig.add_subplot(121)
im = plot_confusion_matrix(cm=rf_nocas["result"]["confusion_matrix"], ax=ax,
classes=["Fatal", "Severe", "Slight"], cmap=plt.cm.Oranges,
title="Random Forest Confusion Matrix for No Casualty Features")
ax = fig.add_subplot(122)
im = plot_confusion_matrix(cm=rf_cas["result"]["confusion_matrix"], ax=ax,
classes=["Fatal", "Severe", "Slight"], cmap=plt.cm.Oranges,
title="Random Forest Confusion Matrix for Casualty Inclusive Features")
plt.show()
rf_roc_nocas = joblib.load("./final/roc_auc_score-all.oh.tlsmote.nocas.v2.rf.pkl.xz")
rf_roc_cas = joblib.load("./final/roc_auc_score-all.oh.tlsmote.cas.v2.rf.pkl.xz")
fig = plt.figure(figsize=[20, 8])
ax = fig.add_subplot(121)
plot_roc_auc_curve(rf_roc_nocas, ax=ax, title="Random Forest ROC Curve for No Casualty Features")
ax = fig.add_subplot(122)
plot_roc_auc_curve(rf_roc_cas , ax=ax, title="Random Forest ROC Curve for Casualty Inclusive Features")
plt.show()
fig = plt.figure(figsize=[25, 5])
ax = fig.add_subplot(121)
fi_nocas = remap_feature_importances(sample_cols["nocascols"], rf_nocas["fi"])
xgb.plot_importance(fi_nocas, ax=ax, show_values=False, xlabel="Gini Impurity Decreased",
title="Feature Importance for No Casualty Features")
ax.set_facecolor("white")
ax = fig.add_subplot(122)
fi_cas = remap_feature_importances(sample_cols["cascols"], rf_cas["fi"])
xgb.plot_importance(fi_cas, ax=ax, show_values=False, xlabel="Gini Impurity Decreased",
title="Feature Importance for Casualty Inclusive Features")
ax.set_facecolor("white")
plt.show()
met.parse_evaluate_model_print(rf_nocas["result"])
met.parse_evaluate_model_print(rf_cas["result"])
nn_cas = joblib.load("./final/nn.val.final.cas.v2.pkl.xz")
nn_model_cas = load_model("./final/nn.model.val.final.cas.v2.h5")
nn_nocas = joblib.load("./final/nn.val.final.nocas.v2.pkl.xz")
nn_model_nocas = load_model("./final/nn.model.val.final.nocas.v2.h5")
epochs = np.arange(1, 201, 1)
# For some reason, this cannot be viewed on GitHub directly.
# To view the two images on GitHub, right-click the image placeholder icon and open the image in a new tab
display(HTML("<table><tr><td><img src='./nn_model_nocas.png'></td><td><img src='./nn_model_cas.png'></td></tr></table>"))
fig = plt.figure(figsize=[15, 5])
ind = 121
ax = fig.add_subplot(ind)
x, y = simple_line_numeric(epochs, nn_nocas["history"]["categorical_accuracy"])
l1 = ax.plot(x, y, label="Training Accuracy")
x, y = simple_line_numeric(epochs, nn_nocas["history"]["val_categorical_accuracy"])
l2 = ax.plot(x, y, label="Validation Accuracy")
ax.legend(loc="center right")
ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")
ind += 1
ax = fig.add_subplot(ind)
x, y = simple_line_numeric(epochs, nn_nocas["history"]["loss"])
l1 = ax.plot(x, y, label="Training Loss")
x, y = simple_line_numeric(epochs, nn_nocas["history"]["val_loss"])
l2 = ax.plot(x, y, label="Validation Loss")
ax.legend(loc="center right")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ind += 1
fig.suptitle("Artificial Neural Network Training for No Casualty Features")
plt.show()
fig = plt.figure(figsize=[15, 5])
ind = 121
ax = fig.add_subplot(ind)
x, y = simple_line_numeric(epochs, nn_cas["history"]["categorical_accuracy"])
l1 = ax.plot(x, y, label="Training Accuracy")
x, y = simple_line_numeric(epochs, nn_cas["history"]["val_categorical_accuracy"])
l2 = ax.plot(x, y, label="Validation Accuracy")
ax.legend(loc="center right")
ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")
ind += 1
ax = fig.add_subplot(ind)
x, y = simple_line_numeric(epochs, nn_cas["history"]["loss"])
l1 = ax.plot(x, y, label="Training Loss")
x, y = simple_line_numeric(epochs, nn_cas["history"]["val_loss"])
l2 = ax.plot(x, y, label="Validation Loss")
ax.legend(loc="center right")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ind += 1
fig.suptitle("Artificial Neural Network Training for Casualty Inclusive Features")
plt.show()
fig = plt.figure(figsize=[15, 15])
ax = fig.add_subplot(121)
im = plot_confusion_matrix(cm=nn_nocas["result"]["confusion_matrix"], ax=ax,
classes=["Fatal", "Severe", "Slight"], cmap=plt.cm.Oranges,
title="Artificial Neural Network Confusion Matrix for No Casualty Features")
ax = fig.add_subplot(122)
im = plot_confusion_matrix(cm=nn_cas["result"]["confusion_matrix"], ax=ax,
classes=["Fatal", "Severe", "Slight"], cmap=plt.cm.Oranges,
title="Artificial Neural Network Confusion Matrix for Casualty Inclusive Features")
plt.show()
nn_roc_nocas = joblib.load("./final/roc_auc_score-all.oh.tlsmote.nocas.v2.nn.pkl.xz")
nn_roc_cas = joblib.load("./final/roc_auc_score-all.oh.tlsmote.cas.v2.nn.pkl.xz")
fig = plt.figure(figsize=[20, 8])
ax = fig.add_subplot(121)
plot_roc_auc_curve(nn_roc_nocas, ax=ax, title="Artificial Neural Network ROC Curve for No Casualty Features")
ax = fig.add_subplot(122)
plot_roc_auc_curve(nn_roc_cas , ax=ax, title="Artificial Neural Network ROC Curve for Casualty Inclusive Features")
plt.show()
met.parse_evaluate_model_print(nn_nocas["result"])
met.parse_evaluate_model_print(nn_cas["result"])